import torch
import torch.nn as nn
import torch.nn.functional as F
from retag_utils import ToyGraphBase, Propagation, FewShotBase, TaskDecoder, ring_contrastive_loss,ring_contrastive_loss_new
from torch_geometric.nn import global_mean_pool

class ReTAG(nn.Module):
    """Retrieval-augmented graph classifier.

    Combines node embeddings from ``pretrain_model`` with 2-cochain ("ring")
    features read from complex objects (``cochains[dim].boundary_index``) and
    with entries retrieved from a ``ToyGraphBase`` built over
    ``resource_dataset``.
    """

    def __init__(self, pretrain_model, resource_dataset, feture_size, num_class, emb_size,
                 finetune=True, noise_finetune=False) -> None:
        """Build the model and its retrieval base.

        Args:
            pretrain_model: encoder exposing ``embed(features, adj) -> (node_emb, _)``.
            resource_dataset: dataset used to build the toy-graph base; its
                ``.name`` is passed to ``FewShotBase``.
            feture_size: unused here; kept for signature compatibility.
            num_class: number of output classes.
            emb_size: node-embedding width, also the ring-projection width.
            finetune: if False, ``forward`` returns the mean retrieved label
                instead of running the decoder.
            noise_finetune: add retrieval noise while training; requires ``finetune``.
        """
        super(ReTAG, self).__init__()
        self.emb_size = emb_size
        self.num_class = num_class
        self.pretrain_model = pretrain_model
        # MLP applied to the per-graph ring mean before it enters the decoder.
        self.ring_proj = nn.Sequential(
            nn.Linear(emb_size, emb_size * 2),
            nn.ReLU(),
            nn.Linear(emb_size * 2, emb_size)
        )

        # Weight of the ring contrastive loss in forward_with_loss (0 disables it).
        self.ring_weight = 0
        # Mixing weight of the retrieved label distribution vs. the decoder softmax in forward.
        self.label_weight = 0.6
        self.finetune = finetune
        self.noise_finetune = noise_finetune
        if self.noise_finetune:
            assert self.finetune, "noise_finetune=True requires finetune=True"

        self.query_graph_hop = 1

        self.toy_graph_base = ToyGraphBase(pretrain_model, num_class, emb_size, self.query_graph_hop)
        self.toy_graph_base.build_toy_graph(resource_dataset)
        self.toy_graph_base.noise_std = 0.05
        # query_emb_pool + rag_embedding + ring_feat.
        # NOTE(review): assumes query_emb_pool is emb_size wide and rag_embedding
        # is 2*emb_size wide (ToyGraphBase entries) — confirm against ToyGraphBase.
        decoder_input_dim = emb_size + 2*emb_size + emb_size
        self.decoder = TaskDecoder(decoder_input_dim, 512, num_class)
        self.reset_parameters()

        self.toy_graph_base.show()
        self.fewshot_base = FewShotBase(resource_dataset.name, num_class, pretrain_model)

    def reset_parameters(self):
        """Re-initialise the task decoder."""
        self.decoder.reset_parameters()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _pool_graph_emb(node_emb, batch=None, ptr=None):
        """Mean-pool node embeddings into one row per graph.

        Uses ``batch`` (graph id per node) when given, otherwise ``ptr``
        (node offsets: graph ``i`` owns rows ``ptr[i]:ptr[i+1]``), otherwise
        treats every node as belonging to a single graph.
        """
        if batch is not None:
            return global_mean_pool(node_emb, batch)
        if ptr is not None:
            per_graph = [node_emb[ptr[i]:ptr[i + 1]].mean(dim=0) for i in range(len(ptr) - 1)]
            return torch.stack(per_graph, dim=0)
        return node_emb.mean(dim=0, keepdim=True)

    @staticmethod
    def _two_cell_boundary(complex_obj):
        """Return ``complex_obj.cochains[2].boundary_index``, or None if it has no usable 2-cells."""
        if complex_obj is None or 2 not in complex_obj.cochains:
            return None
        cochain = complex_obj.cochains[2]
        if cochain is None or cochain.boundary_index is None:
            return None
        if cochain.boundary_index.size(1) == 0:
            return None
        return cochain.boundary_index

    def _ring_loss_and_mean(self, node_emb, complex_obj):
        """Ring contrastive loss and ring-mean vector for a single complex.

        Falls back to zeros when the complex has no 2-cells. The zero
        placeholders require grad, so a caller's ``ring_loss.backward()`` still
        works when no graph in the batch has rings.
        """
        device = node_emb.device
        boundary_2 = self._two_cell_boundary(complex_obj)
        if boundary_2 is None:
            ring_mean = torch.zeros(self.emb_size, device=device, dtype=node_emb.dtype,
                                    requires_grad=True)
            loss = torch.zeros((), device=device, requires_grad=True)
            return loss, ring_mean
        # Move indices to wherever node_emb lives instead of hard-coding .cuda().
        return ring_contrastive_loss(
            node_emb,
            boundary_2.to(device),
            complex_obj.cochains[1].boundary_index.to(device),
        )

    def _ring_features(self, node_emb, complex_batch, batch=None):
        """Return ``(ring_loss, ring_mean)`` with ``ring_mean`` stacked to one row per graph.

        With ``batch``, ``complex_batch`` is indexed by graph id and the per-graph
        losses are summed. Without it, ``complex_batch`` is treated as a single
        complex covering all nodes.
        """
        if batch is not None:
            num_graphs = int(batch.max().item()) + 1
            ring_loss = torch.zeros((), device=node_emb.device)
            ring_means = []
            for graph_idx in range(num_graphs):
                loss_i, mean_i = self._ring_loss_and_mean(node_emb[batch == graph_idx],
                                                          complex_batch[graph_idx])
                ring_loss = ring_loss + loss_i
                ring_means.append(mean_i)
            ring_mean = torch.stack(ring_means, dim=0)
        else:
            ring_loss, ring_mean = self._ring_loss_and_mean(node_emb, complex_batch)
            ring_mean = ring_mean.unsqueeze(0)
        if ring_mean.dim() == 1:
            ring_mean = ring_mean.unsqueeze(0)
        return ring_loss, ring_mean

    def _query_emb_pool(self, adj, node_emb, batch=None):
        """k-hop aggregate node embeddings, then mean-pool per graph (or over all nodes)."""
        query_embeddings = Propagation.aggregate_k_hop_features(adj, node_emb, self.query_graph_hop)
        if batch is not None:
            return global_mean_pool(query_embeddings, batch)
        return query_embeddings.mean(dim=0, keepdim=True)

    def _project_ring(self, ring_mean):
        """Truncate ``ring_mean`` to ``emb_size`` columns if needed and project it with ``ring_proj``.

        Returns ``(possibly truncated ring_mean, ring_feat)``.
        """
        if ring_mean.shape[1] != self.emb_size:
            ring_mean = ring_mean[:, :self.emb_size]
        return ring_mean, self.ring_proj(ring_mean)

    # ------------------------------------------------------------ entry points

    def forward(self, features, adj, complex_batch=None, batch=None, ptr=None, return_ring_loss=False):
        """Predict class scores for one or more graphs.

        Args:
            features, adj: passed to ``pretrain_model.embed``.
            complex_batch: per-graph complexes indexable by graph id when
                ``batch`` is given; otherwise a single complex.
            batch: graph id per node (optional).
            ptr: node offsets per graph (optional, used only for graph pooling).
                NOTE(review): ring features and query pooling ignore ``ptr`` and
                treat all nodes as one graph, so shapes only line up for a single
                graph on this path — confirm intended usage.
            return_ring_loss: if True (and ``finetune``), also return
                ``ring_loss`` and ``ring_mean``.

        Returns:
            With ``finetune``: ``label_logits``, i.e. the decoder softmax mixed
            with the weighted retrieved labels via ``label_weight``. This is
            optionally followed by ``ring_loss`` and ``ring_mean``. Without
            ``finetune``: the mean of the retrieved labels.
        """
        node_emb, _ = self.pretrain_model.embed(features, adj)
        graph_emb = self._pool_graph_emb(node_emb, batch, ptr)
        ring_loss, ring_mean = self._ring_features(node_emb, complex_batch, batch)
        # retain_grad() raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only keep ring_mean.grad when autograd tracks it.
        if ring_mean.requires_grad:
            ring_mean.retain_grad()

        query_keys = torch.cat([graph_emb, ring_mean], dim=-1)

        add_noise = self.training and self.noise_finetune
        rag_embeddings, rag_labels, rag_weights = self.toy_graph_base.retrieve(
            query_keys, adj, complex_batch, ring_mean, add_noise
        )

        if not self.finetune:
            # NOTE(review): return_ring_loss is ignored on this path (single tensor returned).
            return torch.mean(rag_labels, dim=1)

        attn_weights = rag_weights.unsqueeze(-1)
        rag_embedding = torch.sum(attn_weights * rag_embeddings, dim=1)
        rag_label = torch.sum(attn_weights * rag_labels, dim=1)

        query_emb_pool = self._query_emb_pool(adj, node_emb, batch)
        ring_mean, ring_feat = self._project_ring(ring_mean)
        hidden_embedding = torch.cat([query_emb_pool, rag_embedding, ring_feat], dim=-1)
        decode_label = torch.softmax(self.decoder(hidden_embedding), dim=1)
        label_logits = decode_label * (1 - self.label_weight) + rag_label * self.label_weight

        if return_ring_loss:
            return label_logits, ring_loss, ring_mean
        return label_logits

    def forward_with_loss(self, features, adj, complex_batch, label, batch=None, ptr=None):
        """Forward pass plus loss construction, intended for training.

        The classification loss is cross-entropy on the raw decoder logits. The
        ring contrastive loss is added with weight ``ring_weight``. Unlike
        ``forward``, retrieved labels are not mixed into the output.

        Returns:
            ``(total_loss, decode_probs, metrics)``, where ``decode_probs`` is the
            softmax of the decoder logits and ``metrics`` holds float values for
            ``cls_loss``, ``ring_loss`` and ``total_loss``.
        """
        node_emb, _ = self.pretrain_model.embed(features, adj)

        # Graph-level pooling.
        graph_emb = self._pool_graph_emb(node_emb, batch, ptr)

        # Ring features and ring contrastive loss.
        ring_loss, ring_mean = self._ring_features(node_emb, complex_batch, batch)
        _, ring_feat = self._project_ring(ring_mean)

        query_keys = torch.cat([graph_emb, ring_mean], dim=-1)
        rag_embeddings, rag_labels, rag_weights = self.toy_graph_base.retrieve(
            query_keys, adj, complex_batch, ring_mean, add_noise=self.training and self.noise_finetune
        )
        rag_embedding = torch.sum(rag_weights.unsqueeze(-1) * rag_embeddings, dim=1)
        query_emb_pool = self._query_emb_pool(adj, node_emb, batch)

        decoder_input = torch.cat([query_emb_pool, rag_embedding, ring_feat], dim=-1)
        decode_logits = self.decoder(decoder_input)
        decode_probs = torch.softmax(decode_logits, dim=-1)

        cls_loss = F.cross_entropy(decode_logits, label)
        total_loss = cls_loss + self.ring_weight * ring_loss
        return total_loss, decode_probs, {
            "cls_loss": cls_loss.item(),
            "ring_loss": ring_loss.item(),
            "total_loss": total_loss.item()
        }

